package main

import (
	"os"
	"os/exec"
	"path"
	"slices"
	"strings"
	"text/template"

	"fmt"

	"github.com/gomlx/gomlx/internal/backendparser"
	"github.com/gomlx/gomlx/internal/must"
	"github.com/gomlx/gomlx/pkg/support/sets"
	"k8s.io/klog/v2"
)

// standardOpsInterfaceFile is the name of the generated file holding the "not implemented"
// stubs of the standard ops. It is also used as the name of standardOpsTemplate.
const standardOpsInterfaceFile = "gen_standard_ops.go"

// methodsNotGenerated lists Builder methods that get no generated stub because their
// notimplemented versions are maintained manually (or, as noted, their outputs are not
// the standard ones).
var methodsNotGenerated = sets.MakeWith(
	"Constant", "Parameter", "Identity",
	"ReduceWindow",
	"BatchNormForInference", "BatchNormForTraining", "BatchNormGradient",
	"And", "Or", "Xor", "Not",
	"ReduceAnd", "ReduceOr", "ReduceXor",
	"ScatterAdd",
	"AllReduce", // Output is not standard
)

// methodsExcluded lists utility methods that are not part of building a graph: they are
// neither generated nor even given a NodeType.
var methodsExcluded = sets.MakeWith("Name", "Compile", "OpShape")

// standardOpsTemplateText is the text/template source of standardOpsInterfaceFile. It is
// executed over a slice of backendparser.Method and, for each method, emits the method's
// comments followed by a stub returning b.baseErrFn(backends.OpType<Name>).
const standardOpsTemplateText = `/***** File generated by ./internal/cmd/notimplemented_generator, based on github.com/gomlx/gomlx/backends/. Don't edit it directly. *****/

package notimplemented

import (
	"github.com/gomlx/gomlx/backends"
	"github.com/gomlx/gomlx/pkg/core/shapes"
	"github.com/gomlx/gopjrt/dtypes"
)

{{- range .}}
{{- range .Comments}}
{{.}}
{{- end}}
func (b Builder) {{.Name}}({{range .Parameters}}{{.Name}} {{.Type}},{{end}}) (backends.Op, error) {
	return nil, b.baseErrFn(backends.OpType{{.Name}})
}
{{end}}
`

// standardOpsTemplate is the parsed standardOpsTemplateText. It is parsed at package
// initialization, and template.Must panics if the template is malformed.
var standardOpsTemplate = template.Must(
	template.New(standardOpsInterfaceFile).Parse(standardOpsTemplateText),
)

// GenerateStandardOpsInterface generates the interface for the various standard ops: it writes
// standardOpsInterfaceFile in the current working directory, with one notimplemented Builder stub
// per standard op found in methods (sorted by name), and then runs `go fmt` on it.
// The rest of the ops are maintained manually.
//
// A method is a standard op when it is in neither methodsExcluded nor methodsNotGenerated and it
// returns exactly (Op, error). Parameter types are qualified as seen from package notimplemented
// (e.g. "Op" -> "backends.Op").
//
// The caller's methods (including each element's Parameters) are left unmodified. Any failure
// (getting the working directory, creating/writing/closing the file, running `go fmt`) panics
// through must.
func GenerateStandardOpsInterface(methods []backendparser.Method) {
	stubs := make([]backendparser.Method, 0, len(methods))
	for _, method := range methods {
		if !isGeneratedStandardOp(method) {
			continue
		}
		stubs = append(stubs, withQualifiedParamTypes(method))
	}
	slices.SortFunc(stubs, func(a, b backendparser.Method) int { return strings.Compare(a.Name, b.Name) })

	fileName := path.Join(must.M1(os.Getwd()), standardOpsInterfaceFile)
	f := must.M1(os.Create(fileName))
	must.M(standardOpsTemplate.Execute(f, stubs))
	// Close (checking the error) before handing the file to `go fmt`, which rewrites it in place.
	must.M(f.Close())

	cmd := exec.Command("go", "fmt", fileName)
	cmd.Stderr = os.Stderr // Surface gofmt's diagnostics instead of only "exit status 1".
	klog.V(1).Infof("\t%s\n", cmd)
	must.M(cmd.Run())
	fmt.Printf("✅ notimplemented_generator:\tsuccessfully generated %s\n", fileName)
}

// isGeneratedStandardOp reports whether method gets a generated stub: it is neither excluded nor
// maintained manually, and its outputs are the conventional (Op, error).
func isGeneratedStandardOp(method backendparser.Method) bool {
	if methodsExcluded.Has(method.Name) || methodsNotGenerated.Has(method.Name) {
		return false
	}
	return len(method.Outputs) == 2 && method.Outputs[0].Type == "Op" && method.Outputs[1].Type == "error"
}

// withQualifiedParamTypes returns a copy of method whose parameter types are qualified as seen
// from package notimplemented (e.g. "Op" -> "backends.Op"); unknown types are kept as is.
func withQualifiedParamTypes(method backendparser.Method) backendparser.Method {
	// method is already a copy, but its Parameters slice still shares the caller's backing array:
	// clone it so the caller's data is not rewritten.
	method.Parameters = slices.Clone(method.Parameters)
	for i := range method.Parameters {
		p := &method.Parameters[i]
		switch p.Type {
		case "Op":
			p.Type = "backends.Op"
		case "...Op":
			p.Type = "...backends.Op"
		case "[]Op":
			p.Type = "[]backends.Op"
		case "Shape":
			p.Type = "shapes.Shape"
		case "xla_data.FftType", "FFTType":
			p.Type = "backends.FFTType"
		case "ConvolveAxesConfig":
			p.Type = "backends.ConvolveAxesConfig"
		case "...PadAxis":
			p.Type = "...backends.PadAxis"
		}
	}
	return method
}
